{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning Models Using the DPO Algorithm\n",
    "\n",
    "This tutorial demonstrates how to fine-tune large models using the DPO algorithm (using the Llama-3.1-8B model as an example). Through this tutorial, you will learn how to configure training parameters and perform reinforcement learning-style training on preference-labeled data using the DPO algorithm to improve model performance on alignment tasks.\n",
    "\n",
    "## 1. What is the DPO Algorithm?\n",
    "\n",
    "DPO (Direct Preference Optimization) is a method for training language models to better align with human preferences. It does not rely on explicit reward models or policy gradient methods, but directly optimizes the model on “human preference data,” making it prefer the human-preferred response when given two answers.\n",
    "\n",
    "## 2. Environment Setup\n",
    "\n",
    "Before starting, please make sure you have installed the ``align-anything`` package.\n",
    "\n",
    "```bash\n",
    "# Clone the repository\n",
    "git clone git@github.com:PKU-Alignment/align-anything.git\n",
    "cd align-anything\n",
    "\n",
    "# Create a virtual environment using conda\n",
    "conda create -n align-anything python==3.11\n",
    "conda activate align-anything\n",
    "```\n",
    "\n",
    "- **`[Optional]`** We recommend installing [CUDA](https://anaconda.org/nvidia/cuda) in the conda environment and set the environment variable.\n",
    "\n",
    "```bash\n",
    "# We have tested this version of CUDA on the H800 computing cluster and it worked well.\n",
    "# You can adjust this version according to your actual computing cluster.\n",
    "\n",
    "conda install nvidia/label/cuda-12.2.0::cuda\n",
    "export CUDA_HOME=$CONDA_PREFIX\n",
    "```\n",
    "\n",
    "> If your CUDA is installed in a different location, such as `/usr/local/cuda/bin/nvcc`, you can set the environment variable as follows:\n",
    "\n",
    "```bash\n",
    "export CUDA_HOME=\"/usr/local/cuda\"\n",
    "```\n",
    "\n",
    "Finally, install `align-anything` using the following command:\n",
    "\n",
    "```bash\n",
    "# We have prepared a quick installation for training and evaluation.\n",
    "# If you only need to use the training or evaluation module,\n",
    "# you can install the corresponding dependencies.\n",
    "pip install -e .[train] # Install training dependencies\n",
    "pip install -e .[evaluate] # Install evaluation dependencies\n",
    "\n",
    "# If you need to install all dependencies, you can use the following command:\n",
    "pip install -e .[all]\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Llama-3.1-8B-Instruct Model Output Example\n",
    "Next, let's first test the zero-shot capability of the Llama-3.1-8B-Instruct model.\n",
    "\n",
    "### 3.1 Import Required Libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1742778596.498488] [dsw-519274-66f65ff576-678dh:4051137:f]        vfs_fuse.c:281  UCX  ERROR inotify_add_watch(/tmp) failed: No space left on device\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import os\n",
    "import torch\n",
    "\n",
    "os.environ[\"TRANSFORMERS_OFFLINE\"] = \"1\"\n",
    "os.environ[\"HF_DATASETS_OFFLINE\"] = \"1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Load the Original Llama Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128256, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda\"  # Set device to \"cuda\" to use GPU\n",
    "model_path = (\n",
    "    \"/PATH/TO/YOUR/Meta-Llama-3.1-8B-Instruct\"  # Please replace with your actual model path\n",
    ")\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n",
    "\n",
    "# Set the model to evaluation mode\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Test the Performance of the Original Model\n",
    "\n",
    "Let's test the Llama-3.1-8B-Instruct model with a sample question.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generated Text: If a wild animal in your local area has become aggressive and caused injuries, it's essential to take precautions and follow the right steps to ensure your safety and the safety of others. Here's a step-by-step guide:\n",
      "\n",
      "1.  **Stay calm**: Keep a safe distance from the animal and avoid direct confrontation. Panicking can escalate the situation, and you don't want to provoke the animal further.\n",
      "\n",
      "2.  **Identify the animal**: If possible, try to determine the type of animal and its size, as this information will be helpful for wildlife experts or local authorities.\n",
      "\n",
      "3.  **Contact local authorities**: Reach out to local animal control, wildlife services, or a professional wildlife removal service. They will send trained experts to handle the situation.\n",
      "\n",
      "4.  **Keep children and pets indoors**: Ensure that children and pets are safely indoors, away from the area where the animal is present.\n",
      "\n",
      "5.  **Do not approach or feed the animal**: Feeding or approaching the animal can make it more aggressive and increase the risk of injury.\n",
      "\n",
      "6.  **Secure your home**: If the animal is in your yard or near your home, consider securing any food sources, such as pet food, bird seed, or trash, which may be attracting the animal.\n",
      "\n",
      "7.  **Keep a safe distance**: If you're in the same area as the animal, maintain a safe distance and avoid direct eye contact.\n",
      "\n",
      "8.  **Follow instructions from authorities**: If local authorities or wildlife experts arrive, follow their instructions carefully. They may need you to stay away from the area or provide additional information to help them safely remove the animal.\n",
      "\n",
      "9.  **Report any injuries**: If you or someone else has been injured by the animal, seek medical attention immediately and report the incident to local authorities.\n",
      "\n",
      "10. **Prevent future encounters**: If the animal is a recurring problem, consider taking steps to prevent future encounters, such as securing your home, removing attractants, and installing wildlife-proof fencing.\n",
      "\n",
      "Remember, it's always best to prioritize your safety and the safety of others when dealing with aggressive wild animals.\n"
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Recently, a wild animal in the local area has become aggressive towards humans and caused several injuries. How should I handle this wild animal?\",\n",
    "    },\n",
    "]\n",
    "\n",
    "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "# the model generate new tokens\n",
    "with torch.no_grad():\n",
    "    output = model.generate(**inputs, max_new_tokens=2048)\n",
    "# convert the generated tokens to text\n",
    "generated_text = tokenizer.decode(\n",
    "    output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n",
    ")\n",
    "print(\"\\nGenerated Text:\", generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This shows that although Llama 3.1 provides detailed content in its responses, there are issues such as information redundancy and insufficient emphasis on key risks.\n",
    "\n",
    "For example, it suggests \"identifying the animal\" without explicitly warning to stay away from dangerous areas. This could mislead people into approaching for observation, thereby increasing the risk of injury and negatively impacting emergency safety responses in critical situations.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Aligning the Model Using the DPO Algorithm\n",
    "\n",
    "**Note**: If you cannot access huggingface.co, set the Hugging Face endpoint to hf-mirror.com. You can do this with the following command:\n",
    "\n",
    "`export HF_ENDPOINT=\"https://hf-mirror.com\"`\n",
    "\n",
    "Here, we take the PKU-SafeRLHF series dataset as an example. The PKU-SafeRLHF dataset is a preference dataset focused on safety alignment. Each data entry in this dataset contains two responses to the same question, along with their corresponding safety meta-tags and preference annotations.\n",
    "\n",
    "You can refer to the training script below:\n",
    "\n",
    "```bash\n",
    "MODEL_NAME_OR_PATH=\"meta-llama/Llama-3.1-8B-Instruct\" # model path\n",
    "\n",
    "TRAIN_DATASETS=\"PKU-Alignment/PKU-SafeRLHF-single-dimension\" # dataset path\n",
    "TRAIN_TEMPLATE=\"PKUSafeRLHF\" # dataset template\n",
    "TRAIN_SPLIT=\"train\" # split the dataset\n",
    "\n",
    "OUTPUT_DIR=\"../outputs/llama_dpo\" # output dir\n",
    "\n",
    "# For wandb online logging\n",
    "export WANDB_API_KEY=\"YOUR_API_KEY\"\n",
    "\n",
    "# Source the setup script\n",
    "source ./setup.sh\n",
    "\n",
    "# Execute deepspeed command\n",
    "deepspeed \\\n",
    "     --master_port ${MASTER_PORT} \\\n",
    "     --module align_anything.trainers.text_to_text.dpo \\\n",
    "     --model_name_or_path ${MODEL_NAME_OR_PATH} \\\n",
    "     --train_template ${TRAIN_TEMPLATE} \\\n",
    "     --train_datasets ${TRAIN_DATASETS} \\\n",
    "     --train_split ${TRAIN_SPLIT} \\\n",
    "     --output_dir ${OUTPUT_DIR}\n",
    "```\n",
    "\n",
    "After training is completed, you can find the trained model weights under the `OUTPUT_DIR`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Test the Performance of the Model After DPO Training\n",
    "\n",
    "After the training is complete, we try to test whether the alignment of the trained model has improved.\n",
    "\n",
    "### 5.1 Load the New Model Weights\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128257, 4096, padding_idx=128256)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=128257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_path = \"/PATH/TO/YOUR/llama_dpo/slice_end\"  # Please replace with your actual model path\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)\n",
    "\n",
    "# Set the model to evaluation mode\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 Test the Performance of the New Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generated Text: If a wild animal has become aggressive and is causing injuries, it's essential to exercise extreme caution.  First, make sure all people and pets are kept away from the area.  Next, contact a professional wildlife removal expert or local animal control service to safely capture and relocate the animal.  In the meantime, try to determine what may have caused the animal's behavior change, such as habitat loss or food availability, and take steps to mitigate those factors. Finally, educate people in the area on how to peacefully coexist with the animal and what precautions should be taken when interacting with it.\n"
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant that answers user queries.\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Recently, a wild animal in the local area has become aggressive towards humans and caused several injuries. How should I handle this wild animal?\",\n",
    "    },\n",
    "]\n",
    "\n",
    "input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "inputs = tokenizer([input_text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "# the model generate new tokens\n",
    "with torch.no_grad():\n",
    "    output = model.generate(**inputs, max_new_tokens=2048)\n",
    "# convert the generated tokens to text\n",
    "generated_text = tokenizer.decode(\n",
    "    output[0][len(inputs['input_ids'][0]) :], skip_special_tokens=True\n",
    ")\n",
    "print(\"\\nGenerated Text:\", generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This shows that the responses of the trained model are more concise and focused on key safety measures.\n",
    "\n",
    "Phrases like \"stay away,\" \"contact professionals,\" \"analyze the cause,\" and \"public education\" reflect a human-centered, risk-prevention approach that minimizes direct contact, aligning better with the principles of safety alignment.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Acknowledgements\n",
    "\n",
    "- [Hugging Face Transformers Documentation](https://huggingface.co/docs/transformers/index)\n",
    "- [DPO Paper](https://arxiv.org/abs/2305.18290)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jy-align",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
